import os
import torch
from torchvision import transforms
from torchvision.datasets import CocoCaptions
from torch.utils.data import DataLoader
from diffusers import StableDiffusionPipeline
from PIL import Image

# Render Stable Diffusion images from COCO train2014 captions.
# For each of the first MAX_IMAGES COCO images (in dataset order), up to
# CAPTIONS_PER_IMAGE of its captions are each turned into one generated image,
# saved as "<coco_image_id>_<caption_number>.png" in output_dir.

# Paths
coco_root = './data/datasets/coco/'
ann_file = os.path.join(coco_root, 'annotations/captions_train2014.json')
img_folder = os.path.join(coco_root, 'train2014')
output_dir = './images'
os.makedirs(output_dir, exist_ok=True)

# Generation settings
MAX_IMAGES = 10000          # number of COCO images whose captions are rendered
CAPTIONS_PER_IMAGE = 5      # at most this many captions are used per COCO image
IMAGE_SIZE = 224            # generated height/width; the pipeline requires a multiple of 8
# NOTE(review): SD v1.5 was trained at 512x512; outputs at 224 may be degraded —
# confirm this size is intended (e.g. to match a 224px evaluation model).
NUM_INFERENCE_STEPS = 25    # denoising steps per generated image

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])
coco_dataset = CocoCaptions(root=img_folder, annFile=ann_file, transform=transform)
dataloader = DataLoader(coco_dataset, batch_size=1, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to(device)
# CocoCaptions resolves index -> image through its sorted ``ids`` list, whereas
# ``coco.getImgIds()`` returns ids in annotation-file order. Use the dataset's own
# list so that ``image_ids[idx]`` is the id of the sample loaded at position ``idx``
# (valid because the DataLoader does not shuffle).
image_ids = coco_dataset.ids

# The image tensor is loaded by the dataset but not used here; only captions are.
for idx, (_image, captions) in enumerate(dataloader):
    if idx >= MAX_IMAGES:
        break
    image_id = image_ids[idx]
    # default_collate transposes the per-image caption list, so with batch_size=1
    # ``captions`` looks like [(cap_0,), (cap_1,), ...]. Unwrap each 1-tuple;
    # indexing ``captions[0]`` would yield only the first caption.
    caption_batch = [caption[0] for caption in captions[:CAPTIONS_PER_IMAGE]]

    for i, caption_text in enumerate(caption_batch, start=1):
        with torch.no_grad():
            generated_image = pipe(
                caption_text,
                height=IMAGE_SIZE,
                width=IMAGE_SIZE,
                # The pipeline resets the scheduler's timesteps on every call, so the
                # step count must be passed here; calling set_timesteps beforehand
                # has no effect.
                num_inference_steps=NUM_INFERENCE_STEPS,
            ).images[0]
        save_path = os.path.join(output_dir, f"{image_id}_{i}.png")
        generated_image.save(save_path)

    if idx % 100 == 0:
        print(f"Processed {idx + 1} images")

print("Done generating images.")